{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_text(text, target_length=1057):\n",
    "    numbers = list(map(int, text.split()))\n",
    "    if len(numbers) < target_length:\n",
    "        numbers.extend([0] * (target_length - len(numbers)))  # 填充0\n",
    "    elif len(numbers) > target_length:\n",
    "        numbers = numbers[:target_length]  # 截断\n",
    "    return numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 1.2 转化成dataloader\n",
    "class LabelDataset(Dataset):\n",
    "    def __init__(self, label_sequences, multi_class_labels):\n",
    "        self.label_sequences = [torch.tensor(seq, dtype=torch.long) for seq in label_sequences]\n",
    "        self.multi_class_labels = multi_class_labels\n",
    " \n",
    "    def __len__(self):\n",
    "        return len(self.label_sequences)\n",
    " \n",
    "    def __getitem__(self, idx):\n",
    "        return self.label_sequences[idx], self.multi_class_labels[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. 定义模型\n",
    "class LSTMClassifier(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes):\n",
    "        super(LSTMClassifier, self).__init__()\n",
    "        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_dim, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        embedded = self.embedding(x)  # 嵌入层\n",
    "        lstm_out, _ = self.lstm(embedded)  # LSTM 层\n",
    "        # 取 LSTM 的最后一个时间步的输出\n",
    "        last_output = lstm_out[:, -1, :]\n",
    "        output = self.fc(last_output)  # 全连接层\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# 参数设置\n",
    "vocab_size = 10000\n",
    "embedding_dim = 128\n",
    "hidden_dim = 20\n",
    "num_classes = 14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LSTMClassifier(\n",
       "  (embedding): Embedding(10000, 128)\n",
       "  (lstm): LSTM(128, 20, batch_first=True)\n",
       "  (fc): Linear(in_features=20, out_features=14, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 4. 模型存储与加载\n",
    "## 4.1 模型存储\n",
    "# torch.save(model.state_dict(), 'lstm_classifier.pth')\n",
    "\n",
    "## 4.2 模型加载\n",
    "# 定义模型结构\n",
    "model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, num_classes)\n",
    "# 加载模型状态字典\n",
    "model.load_state_dict(torch.load('lstm_classifier.pth'))\n",
    "# 将模型移动到 GPU（如果可用）\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5 模型预测\n",
    "## 5.1 预测一条数据\n",
    "test01 = '2673 5076 6835 2835 5948 5677 3247 4124 2465 5192 6101 913 4326 3615 442 4409 2466 6552 465 7399 1859 7449 5192 6101 913 4326 648 4490 7160 2400 3750 5915 2975 5393 7091 4659 4811 910 5731 7013 1264 3641 4939 1613 4659 2265 648 62 3080 3750 5998 7492 7509 3223 4893 3641 6453 1699 4939 2984 2717 6587 1068 2745 656 648 6122 1903 1736 3915 2515 4166 900 6143 19 2595 2673 5076 4562 137 1401 5998 2729 5028 4490 3750 803 1679 4811 5731 7013 4655 4124 5598 3750 6587 1068 2745 656 4998 3800 2990 623 3231 1635 2465 913 4326 6587 1068 1031 761 4939 2210 3961 151 1815 648 2522 761 3750 5470 5328 4939 5192 6101 913 4326 3750 6835 2835 3117 6898 6966 5430 4893 3585 5192 340 913 4326 3430 3310 1277 3750 1871 6929 5612 5430 648 2835 3310 2304 669 4659 5235 6333 3915 6909 3750 1394 7509 1571 6713 3585 6093 3529 5619 2974 2612 2456 5338 900 1460 7377 3961 5619 4893 2444 1815 1699 3750 5192 6101 913 4326 4939 1613 5235 6333 1866 3461 955 872 619 6835 2835 648 1816 5430 3750 742 25 3750 5037 619 4167 5410 6357 2614 299 5192 2835 1859 5430 2120 4289 5780 1043 265 6734 5235 6333 669 2304 5338 980 6017 980 6027 1736 3750 5948 1395 2087 730 6045 1460 5192 6101 913 4326 669 4409 2466 900 2465 5192 6101 913 4326 6835 2835 5948 5677 3247 4124 5780 1043 6027 1736 5192 6101 913 4326 2541 1460 619 1722 5305 913 4326 25 1457 3750 6637 6678 4128 3824 2028 4269 606 4683 5778 2109 6350 2087 730 6045 648 4117 5436 3750 3618 5192 6101 913 4326 5235 6333 1031 761 3750 4128 6122 6405 729 4448 7123 7509 3223 2087 730 6045 900 2465 7399 6656 2477 2400 2252 5780 3750 7509 3223 3893 4396 3504 4333 6038 4231 648 913 4326 1816 5430 648 4939 803 6043 4068 1622 1635 5192 2835 900 1460 619 2087 730 6045 6017 5028 3750 803 6043 4068 1622 1635 5192 2835 1394 4939 5235 6333 340 3893 383 648 6835 5330 6333 3720 3750 7399 1859 7449 648 4490 7160 2400 7091 5028 6966 5430 648 4939 2642 4909 6043 4068 1622 648 5192 5310 3750 1394 4124 4269 3270 5598 1906 3870 449 900 3618 4939 6122 4778 913 4326 1080 2614 3750 7509 3223 3961 6831 1394 3800 4151 4293 3000 7420 3523 1920 2313 2974 2612 5785 2866 6027 1736 3750 5998 4998 19 6508 6038 4231 299 2212 6630 5502 5598 900 2465 5051 5598 6453 3750 6357 2614 299 4811 2799 3370 7186 648 5192 6101 913 4326 3605 4333 6038 4231 1816 5430 648 4939 6043 4068 1622 5192 2835 3750 25 1722 4181 299 4969 6887 4811 3196 7370 3223 670 3961 5430 900 4811 7509 3223 1299 913 3605 1679 4939 1405 7123 803 6043 4068 1622 1635 648 6014 4396 3750 4811 648 1679 4811 2717 6980 5430 648 4939 2364 3189 648 5192 2835 900 2465 803 2364 3220 1699 5998 2729 648 5537 6040 1696 1696 2252 5620 3750 2087 730 6045 1460 5192 6101 913 4326 4355 4392 4409 3870 3750 6609 1903 2210 3961 1460 5192 6101 913 4326 3605 1854 3744 4040 6886 4480 3750 5948 1395 619 6833 2364 6966 5430 2642 4909 6662 2597 5491 6043 4068 1622 2662 4180 4464 433 5235 2835 648 913 4326 1299 6630 4659 5598 496 2042 3231 1635 1141 3120 3630 913 4326 1245 5915 3641 6983 3238 3686 2446 2968 1267 7254 3750 803 5998 4128 4939 1699 3373 7370 3106 4779 3961 6831 3860 5688 656 3257 913 4326 2597 4036 340 4287 5139 3750 25 1859 4237 3106 6444 151 2859 5780 2614 1734 913 4326 648 1736 3915 4462 3220 900 1635 2465 5731 7013 4939 2597 4036 4969 1279 2984 2717 6587 1068 2745 656 648 2515 4166 7399 1859 7449 5192 6101 913 4326 648 4490 7160 2400 3750 5915 2975 5393 7091 4659 4811 910 5731 7013 1264 3641 4939 1613 4659 2265 648 62 3080 3750 5998 7492 7509 3223 4893 3641 6453 1699 4939 2984 2717 6587 1068 2745 656 648 6122 1903 1736 3915 2515 4166 900 6143 19 2595 2673 5076 4562 137 1401 5998 2729 5028 4490 3750 803 1679 4811 5731 7013 4655 4124 5598 3750 6587 1068 2745 656 4998 3800 2990 623 3231 1635 2465 1460 7377 3750 1816 5430 6043 4068 1622 5192 2835 648 1141 3120 3630 913 4326 4105 1779 4369 1245 5915 3641 1702 3300 3750 7399 1859 7449 913 4326 648 4490 7160 2400 3750 3263 4411 913 1815 1906 648 803 3479 4261 3374 3893 1635 5338 5803 3750 5801 4939 1613 4811 1913 5422 648 4842 6725 2112 2612 5436 3750 3407 913 4326 648 5731 7013 4939 1613 6027 873 3750 5998 2489 3605 4939 1279 2975 2255 3915 648 900 3618 1465 3166 5445 6832 2376 3750 913 4326 5731 7013 4939 1613 4671 3646 1679 290 1702 2597 4036 383 6980 307 623 3750 25 1460 4822 4811 2990 623 6587 1068 2745 656 6980 648 5028 5949 3750 1465 5864 1702 3300 2087 730 6045 1080 4958 5598 7117 7305 900 2465'\n",
    "test01_list = list(map(int,test01.split()))\n",
    "test01_tensor = torch.tensor([test01_list], dtype=torch.long)\n",
    "# 注意：不带标签数据不用转换为dataset\n",
    "# dataset = LabelDataset(text_tensor, train['label'])\n",
    "dataloader = DataLoader(test01_tensor, batch_size=1)\n",
    "with torch.no_grad():\n",
    "    outputs = model(test01_tensor.to(device))\n",
    "    prob = torch.nn.functional.softmax(outputs, dim=1)\n",
    "    predicted_classes = prob.argmax(dim=1)\n",
    "    predicted_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[8]"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted_classes.cpu().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. 模型预测\n",
    "test_text = test['text'].tolist()\n",
    "test_numbers = [process_text(item) for item in test_text]\n",
    "test_tensor = torch.tensor(test_numbers, dtype=torch.long)\n",
    "dataloader = DataLoader(test_tensor, batch_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = []\n",
    "with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "        output = model(batch.to(device))\n",
    "        prob = torch.nn.functional.softmax(output, dim=1)\n",
    "        predicted_classes = prob.argmax(dim=1)\n",
    "        preds.append(predicted_classes.cpu().tolist()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 5.2. 结果保存\n",
    "test['label'] = preds\n",
    "test['label'].to_csv(\"/root/data/LSTM.csv\", index=False) "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "countVctorRidgeclassification",
   "language": "python",
   "name": "countvctorridgeclassification"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
